{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Fashion MNIST复习\n",
    "\n",
    "最新的TensorFlow已经使用fashion mnist数据集替代mnist数据集作为起始教程了。\n",
    "\n",
    "我最近学了斯坦福的理论基础，现在准备回过头来复习一遍基础，以此加深理解。\n",
    "\n",
    "以前的mnist数据集是手写数字的图片，现在的fashion mnist数据集是一些衣物的图片，但是两者图片的大小都是一致的，均为`28*28`。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集\n",
    "\n",
    "现在，先加载数据集："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "\n",
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "# 第一次会从网络下载数据集，目录在当前用户目录，以我为例，在/home/allen/.keras/datasets/fashion_mnist\n",
    "train_images, train_labels, test_images, test_labels = fashion_mnist.load_data()\n",
    "# train_images的形状为[60000, 28, 28]，代表60000张图片，每张图片为28*28像素\n",
    "# train_labels是一个长度为60000的列表，每个值代表对用train image的类别\n",
    "# test_images的形状是[10000, 28, 28]，test_labels是长度为10000的列表\n",
    "\n",
    "# 类别的字符描述，把train labels提供的数字类别作为索引，可以取出对应的类别字符描述\n",
    "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 看看图片是什么样的\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure()\n",
    "plt.imshow(train_images[0])\n",
    "plt.colorbar()\n",
    "plt.grid(False)\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "for i in range(25):\n",
    "    plt.subplot(5,5,i+1)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    plt.grid(False)\n",
    "    # 输出二值化图像\n",
    "    plt.imshow(train_images[i], cmap=plt.cm.binary)\n",
    "    plt.xlabel(class_names[train_labels[i]])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![fashion_mnist_train_image_0](images/fashion_mnist_train_image_0.png)\n",
    "![fashion_mnist_train_images_25](images/fashion_mnist_train_images_25.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建模型\n",
    "\n",
    "我们使用keras的sequential api构建模型:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model build.\n",
      "Output shape of flatten layer is:\n",
      "(None, 784)\n",
      "Output shape of dense layer is:\n",
      "(None, 128)\n",
      "Output shape of output layer is:\n",
      "(None, 10)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import tensorflow as tf\n",
    "\n",
    "# 使用keras的序贯模型\n",
    "model = keras.Sequential()\n",
    "# 将`28*28`的二维图像，展平成一维\n",
    "flatten = keras.layers.Flatten(input_shape=(28, 28))\n",
    "model.add(flatten)\n",
    "\n",
    "# 添加全连接层\n",
    "dense = keras.layers.Dense(128, activation=keras.activations.relu)\n",
    "model.add(dense)\n",
    "\n",
    "# 添加输出层，使用softmax函数进行多分类，获得每个类别的概率\n",
    "output = keras.layers.Dense(10, activation=keras.activations.softmax)\n",
    "model.add(output)\n",
    "\n",
    "# 编译模型\n",
    "# 1.使用adam作为优化器\n",
    "# 2.使用loss=keras.losses.sparse_categorical_crossentropy作为损失，\n",
    "#  如果使用tensorflow作为后端，它实际上就是tf.nn.sparse_softmax_cross_entropy_with_logits，\n",
    "#  因为我们的输出是整数，所以使用这个loss\n",
    "# 3.metrics我们使用accuracy，用来统计正确率\n",
    "model.compile(optimizer=keras.optimizers.Adam(),\n",
    "              loss=keras.losses.sparse_categorical_crossentropy,\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "print(\"Model build.\")\n",
    "print(\"Output shape of flatten layer is:\")\n",
    "print(flatten.output_shape)\n",
    "\n",
    "print(\"Output shape of dense layer is:\")\n",
    "print(dense.output_shape)\n",
    "\n",
    "print(\"Output shape of output layer is:\")\n",
    "print(output.output_shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "关于上面的loss，还需要多说一句。\n",
    "tf.nn.sparse_softmax_cross_entropy_with_logits的参数需要的是logits，而keras.losses.sparse_categorical_crossentropy的参数需要的是probabilities。keras.losses.sparse_categorical_crossentropy有个参数from_logits，如果是False(默认值)，就需要对第一个输入参数进行一次对数运算，然后reshape，转化成logits。\n",
    "\n",
    "```python\n",
    "\n",
    "output = tf.log(output)\n",
    "output_shape = output.get_shape()\n",
    "logits　＝　tf.reshape(output, [-1, int(output_shape[-1])])\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "使用keras训练模型很简单：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用train_images和对应的标签train_labels训练5个轮回\n",
    "model.fit(train_images, train_labels, epochs=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 估算模型的准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(test_images, test_labels)\n",
    "print('Test loss:',test_loss)\n",
    "print('Test accuracy:', test_acc)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测\n",
    "\n",
    "我们使用test_images作为输入，预测分类：\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "predictions = model.predict(test_images)\n",
    "predict_num_classes = []\n",
    "for i in range(len(test_labels)):\n",
    "    # 使用argmax操作，从softmax的输出概率中，选择最有可能的类别\n",
    "    prediction = np.argmax(predictions[i])\n",
    "    print(prediction)\n",
    "    predict_num_classes.append(prediction)\n",
    "\n",
    "print(predict_num_classes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
